{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(193270:281473817470512,MainProcess):2021-03-16-15:29:32.819.845 [mindspore/_check_version.py:207] MindSpore version 1.1.1 and \"te\" wheel package version 1.0 does not match, reference to the match info on: https://www.mindspore.cn/install\n",
      "MindSpore version 1.1.1 and \"topi\" wheel package version 0.6.0 does not match, reference to the match info on: https://www.mindspore.cn/install\n",
      "[WARNING] ME(193270:281473817470512,MainProcess):2021-03-16-15:29:33.378.527 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "from mindspore import Tensor, nn, Model, context\n",
    "from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor\n",
    "from mindspore.communication.management import init, get_rank\n",
    "from mindspore.context import ParallelMode\n",
    "from mindspore.train.serialization import load_param_into_net, load_checkpoint\n",
    "\n",
    "from src.preprocess import convert_to_mindrecord\n",
    "from src.dataset import create_dataset\n",
    "from src.seq2seq import Seq2Seq, WithLossCell, InferCell\n",
    "from src.config import cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ascend card to run on. Taken from the DEVICE_ID environment variable when that is set\n",
    "# (and non-empty); otherwise falls back to card 4, the value this notebook used to hard-code.\n",
    "device_id = int(os.getenv('DEVICE_ID') or 4)\n",
    "\n",
    "# Graph mode on the Ascend backend, with save_graphs turned off.\n",
    "context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target='Ascend', device_id=device_id)"
   ]
  },
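  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train\n",
    "\n",
    "Build the training dataset, wrap `Seq2Seq` in `WithLossCell`, and train it with Adam, saving checkpoints along the way."
   ]
  },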
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the training dataset from the path and batch size configured in cfg.\n",
    "ds_train = create_dataset(\n",
    "    cfg.dataset_path,\n",
    "    cfg.batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seq2Seq wrapped in WithLossCell; Model below is given no separate loss_fn.\n",
    "network = WithLossCell(Seq2Seq(cfg), cfg)\n",
    "\n",
    "# Adam over the trainable parameters of the wrapped network.\n",
    "optimizer = nn.Adam(\n",
    "    network.trainable_params(),\n",
    "    learning_rate=cfg.learning_rate,\n",
    "    beta1=0.9,\n",
    "    beta2=0.98,\n",
    ")\n",
    "\n",
    "model = Model(network, optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 125, loss is 2.804515\n",
      "epoch time: 32208.512 ms, per step time: 257.668 ms\n",
      "epoch: 2 step: 125, loss is 1.963039\n",
      "epoch time: 11227.136 ms, per step time: 89.817 ms\n",
      "epoch: 3 step: 125, loss is 1.8751457\n",
      "epoch time: 11207.574 ms, per step time: 89.661 ms\n",
      "epoch: 4 step: 125, loss is 2.0917926\n",
      "epoch time: 11235.453 ms, per step time: 89.884 ms\n",
      "epoch: 5 step: 125, loss is 1.5626856\n",
      "epoch time: 11257.191 ms, per step time: 90.058 ms\n",
      "epoch: 6 step: 125, loss is 1.0996865\n",
      "epoch time: 11264.321 ms, per step time: 90.115 ms\n",
      "epoch: 7 step: 125, loss is 0.9826399\n",
      "epoch time: 11222.325 ms, per step time: 89.779 ms\n",
      "epoch: 8 step: 125, loss is 0.61559135\n",
      "epoch time: 11283.613 ms, per step time: 90.269 ms\n",
      "epoch: 9 step: 125, loss is 0.34942892\n",
      "epoch time: 11223.944 ms, per step time: 89.792 ms\n",
      "epoch: 10 step: 125, loss is 0.32617155\n",
      "epoch time: 11203.418 ms, per step time: 89.627 ms\n",
      "epoch: 11 step: 125, loss is 0.25858104\n",
      "epoch time: 11256.972 ms, per step time: 90.056 ms\n",
      "epoch: 12 step: 125, loss is 0.25984508\n",
      "epoch time: 11243.845 ms, per step time: 89.951 ms\n",
      "epoch: 13 step: 125, loss is 0.13721837\n",
      "epoch time: 11246.996 ms, per step time: 89.976 ms\n",
      "epoch: 14 step: 125, loss is 0.12634051\n",
      "epoch time: 11269.224 ms, per step time: 90.154 ms\n",
      "epoch: 15 step: 125, loss is 0.07898388\n",
      "epoch time: 11256.715 ms, per step time: 90.054 ms\n"
     ]
    }
   ],
   "source": [
    "# Callbacks, listed in the order they are handed to model.train: timing, checkpointing, loss logging.\n",
    "time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())\n",
    "\n",
    "config_ck = CheckpointConfig(\n",
    "    save_checkpoint_steps=cfg.save_checkpoint_steps,\n",
    "    keep_checkpoint_max=cfg.keep_checkpoint_max,\n",
    ")\n",
    "ckpoint_cb = ModelCheckpoint(prefix=\"gru\", directory=cfg.ckpt_save_path, config=config_ck)\n",
    "\n",
    "loss_cb = LossMonitor()\n",
    "\n",
    "callbacks = [time_cb, ckpoint_cb, loss_cb]\n",
    "\n",
    "# Train for cfg.num_epochs epochs with dataset sink mode enabled.\n",
    "model.train(\n",
    "    cfg.num_epochs,\n",
    "    ds_train,\n",
    "    callbacks=callbacks,\n",
    "    dataset_sink_mode=True,\n",
    ")"
   ]
  },
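  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate\n",
    "\n",
    "Load the checkpoint into an inference network and print, for the first sentence of each evaluation batch, the English source, the expected Chinese translation and the predicted one."
   ]
  },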
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation dataset: same data path as training, built with cfg.eval_batch_size and is_training=False.\n",
    "ds_eval = create_dataset(cfg.dataset_path, cfg.eval_batch_size, is_training=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference network: Seq2Seq built with is_train=False, wrapped in InferCell, training mode off.\n",
    "network = InferCell(Seq2Seq(cfg, is_train=False), cfg)\n",
    "network.set_train(False)\n",
    "\n",
    "# Load the parameters stored at cfg.checkpoint_path into the inference network.\n",
    "parameter_dict = load_checkpoint(cfg.checkpoint_path)\n",
    "load_param_into_net(network, parameter_dict)\n",
    "\n",
    "model = Model(network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "English: do you like snow ? \n",
      "expect Chinese: 你喜欢雪吗？\n",
      "predict Chinese: 你喜欢雪吗？\n",
      " \n",
      "English: i can see tom . \n",
      "expect Chinese: 我看得见汤姆。\n",
      "predict Chinese: 我看得见汤姆。\n",
      " \n",
      "English: stay sharp . \n",
      "expect Chinese: 保持警惕。\n",
      "predict Chinese: 保持警惕。\n",
      " \n",
      "English: stop meddling . \n",
      "expect Chinese: 别再插手。\n",
      "predict Chinese: 别再插手。\n",
      " \n",
      "English: tom is a magician . \n",
      "expect Chinese: 汤姆是魔法师。\n",
      "predict Chinese: 汤姆是魔法师。\n",
      " \n",
      "English: i am very sad . \n",
      "expect Chinese: 我很难过。\n",
      "predict Chinese: 我很难过。\n",
      " \n",
      "English: i m very happy . \n",
      "expect Chinese: 我很快乐。\n",
      "predict Chinese: 我很快乐。\n",
      " \n",
      "English: don t let tom die . \n",
      "expect Chinese: 别让汤姆死了。\n",
      "predict Chinese: 别让汤姆死了。\n",
      " \n",
      "English: time flies . \n",
      "expect Chinese: 时光飞逝。\n",
      "predict Chinese: 时光飞逝。\n",
      " \n",
      "English: let s turn back . \n",
      "expect Chinese: 我们掉头吧！\n",
      "predict Chinese: 我们掉头吧！\n",
      " \n",
      "English: he caught a cold . \n",
      "expect Chinese: 他感冒了。\n",
      "predict Chinese: 他着凉了。\n",
      " \n",
      "English: is it all there ? \n",
      "expect Chinese: 全都在那里吗？\n",
      "predict Chinese: 全都在那里吗？\n",
      " \n",
      "English: take me home . \n",
      "expect Chinese: 带我回家。\n",
      "predict Chinese: 带我回家。\n",
      " \n",
      "English: anyone can do that . \n",
      "expect Chinese: 任何人都可以做到。\n",
      "predict Chinese: 任何人都可以做到。\n",
      " \n",
      "English: balls are round . \n",
      "expect Chinese: 球是圆的。\n",
      "predict Chinese: 球是圆的。\n",
      " \n",
      "English: back off . \n",
      "expect Chinese: 往后退点。\n",
      "predict Chinese: 往后退点。\n",
      " \n",
      "English: where am i ? \n",
      "expect Chinese: 我在哪里？\n",
      "predict Chinese: 我在哪里？\n",
      " \n",
      "English: i hope so . \n",
      "expect Chinese: 我希望如此。\n",
      "predict Chinese: 我希望如此。\n",
      " \n",
      "English: we can begin . \n",
      "expect Chinese: 我们能开始。\n",
      "predict Chinese: 我们能开始。\n",
      " \n",
      "English: tom would accept . \n",
      "expect Chinese: 汤姆会接受。\n",
      "predict Chinese: 汤姆会接受。\n",
      " \n"
     ]
    }
   ],
   "source": [
    "def load_vocab(file_name):\n",
    "    \"\"\"Return the lines of a vocabulary file under cfg.dataset_path.\n",
    "\n",
    "    The file is split on newlines, so a token's id is the index of its line.\n",
    "    \"\"\"\n",
    "    with open(os.path.join(cfg.dataset_path, file_name), 'r', encoding='utf-8') as f:\n",
    "        return f.read().split('\\n')\n",
    "\n",
    "\n",
    "def ids_to_tokens(ids, vocab, skip_ids=()):\n",
    "    \"\"\"Look up each id in vocab, stopping at the first id 0 and dropping ids listed in skip_ids.\"\"\"\n",
    "    tokens = []\n",
    "    for token_id in ids:\n",
    "        if token_id == 0:\n",
    "            break\n",
    "        if token_id in skip_ids:\n",
    "            continue\n",
    "        tokens.append(vocab[token_id])\n",
    "    return tokens\n",
    "\n",
    "\n",
    "en_vocab = load_vocab(\"en_vocab.txt\")\n",
    "ch_vocab = load_vocab(\"ch_vocab.txt\")\n",
    "\n",
    "for batch in ds_eval.create_dict_iterator():\n",
    "    # Only the first sample of each batch is decoded and printed.\n",
    "    en_tokens = ids_to_tokens(batch['encoder_data'][0].asnumpy(), en_vocab)\n",
    "    # Id 1 is skipped in the reference sequence (presumably a start-of-sentence marker;\n",
    "    # confirm against src/preprocess.py).\n",
    "    expected_tokens = ids_to_tokens(batch['decoder_data'][0].asnumpy(), ch_vocab, skip_ids=(1,))\n",
    "\n",
    "    output = network(batch['encoder_data'], batch['decoder_data'])\n",
    "\n",
    "    # Every English token is followed by a space, including the last one.\n",
    "    print('English:', ''.join(token + ' ' for token in en_tokens))\n",
    "    print('expect Chinese:', ''.join(expected_tokens))\n",
    "    print('predict Chinese:', ''.join(ids_to_tokens(output[0].asnumpy(), ch_vocab)))\n",
    "    print(' ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
